import os
import sys
import torch
import torch.nn as nn
from torch.nn import Sequential, Module, ModuleList
# from model.net.jemwrn import JEMWRN
from model.net.wideresnet import Wide_ResNet
from model.net.hem import HEM
from data.dl_getter import normalization_infos

from model.net.resnet_big_con import SupConResNet, SupCEResNet, ResEnc
from model.net.energy_head import LinearClassifier, CosSim, RecipNorm, \
                                  RecipNonlinearNorm, EnergyProjHead, \
                                  ConProjHead, MCogBase
from model.model_io import load_model, load_pretrained_weights


def get_model(args):
    """Build the network for the current run and move it onto the GPU.

    The base network comes from ``_init_model``. What happens next depends
    on ``args.eval`` and ``args.method``:

    * eval + ``'finetune'``: the head is replaced first, then the checkpoint
      at ``args.load_path`` is loaded into the re-headed model.
    * eval, other method: the checkpoint is loaded into the base network.
    * no eval + ``'finetune'``: the checkpoint is loaded into the base
      network first, then the head is replaced.
    * no eval, other method: the freshly initialised network is used as is.

    Args:
        args: run configuration; reads ``eval``, ``method``, ``head``,
            ``n_cls``, ``feat_dim`` and ``load_path`` (plus whatever
            ``_init_model`` and ``load_model`` read).

    Returns:
        The model after ``.cuda()``, wrapped in ``torch.nn.DataParallel``
        when more than one CUDA device is visible.
    """
    net = _init_model(args)
    head_registry = dict(
        lin=LinearClassifier,
        con=ConProjHead,
        cos=CosSim,
        rc=RecipNorm,
        rcn=RecipNonlinearNorm,
        ep=EnergyProjHead,
    )

    def _with_new_head(base):
        # Keep the encoder of `base`, attach a fresh head chosen by args.head.
        head_cls = head_registry[args.head]
        fresh_head = head_cls(args, base.features_dim, args.n_cls,
                              args.feat_dim)
        return HEM(base.encoder, fresh_head)

    evaluating = args.eval
    finetuning = args.method == 'finetune'

    if evaluating and finetuning:
        # Checkpoint matches the re-headed model: swap the head, then load.
        net = _with_new_head(net)
        load_model(args, args.load_path, model=net)
    elif evaluating:
        load_model(args, args.load_path, model=net)
    elif finetuning:
        # Checkpoint matches the base model: load, then swap the head.
        load_model(args, args.load_path, model=net)
        net = _with_new_head(net)

    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)
    return net.cuda()


def get_pretrained_model(args):
    """Build a backbone with pretrained weights and a linear classifier.

    The backbone is moved to the GPU, switched to eval mode and filled via
    ``load_pretrained_weights``. The classifier is moved to the GPU and
    wrapped in ``DistributedDataParallel`` on device ``args.gpu``.

    Args:
        args: run configuration; reads ``pretrained_weights``,
            ``checkpoint_key``, ``arch``, ``patch_size``, ``num_labels``
            and ``gpu``.

    Returns:
        Tuple ``(backbone, linear_classifier)``.
    """
    # For now this is only called during fine-tuning.
    # NOTE(review): the `_init_model` defined in this module returns a single
    # model object, while this call unpacks two values (model, embed_dim).
    # The LinearClassifier call below also uses a different signature from
    # the head construction elsewhere in this file -- confirm this function
    # still works before relying on it.
    backbone, embed_dim = _init_model(args)
    backbone.cuda()
    backbone.eval()

    # Load the weights to evaluate.
    load_pretrained_weights(
        backbone,
        args.pretrained_weights,
        args.checkpoint_key,
        args.arch,
        args.patch_size,
    )
    print(f"Model {args.arch} built.")

    classifier = LinearClassifier(embed_dim, num_labels=args.num_labels).cuda()
    classifier = nn.parallel.DistributedDataParallel(
        classifier,
        find_unused_parameters=True,
        device_ids=[args.gpu],
    )
    return backbone, classifier


def _init_model(args):
    # ============ building network ... ============
    # if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base)
    args.arch = args.arch.replace("deit", "vit")
    num_classes = args.num_labels

    in_norm = None
    if args.in_norm:
        norm_info = normalization_infos[args.dataset]
        in_norm = NormalizeInput(mean=norm_info[0], std=norm_info[1])
    if 'sc' in args.arch:
        backbone = 'resnet34' if '34' in args.arch else 'resnet18'
        model = SupConModel(backbone=backbone,
                            second_stage=True,
                            num_classes=args.n_cls)
        return model
    elif 'con' in args.arch:
        model = SupConResNet(args, name='resnet50')
        return model
    elif 'ce' in args.arch:
        name='resnet50'
        model = SupCEResNet(name='resnet50', num_classes=args.n_cls, head=None)
        return model
    elif 'resnet' in args.arch :
        enc = ResEnc(args.arch)

    elif 'wrn' in args.arch:
        name_parts = args.arch.split('-')
        num_classes = args.num_labels
        depth = int(name_parts[1])
        widen = int(name_parts[2])
        # enc = JEMWRN(depth, widen, num_classes=num_classes, in_channel=args.in_ch,
        #                in_norm=in_norm, debug=args.debug, i_act_u=args.i_act_u)
        enc = Wide_ResNet(depth, widen, norm='batch', dropout_rate=0.,
                          in_norm=in_norm, num_classes=args.n_cls)

    else:
        print(f"Unknow architecture: {args.arch}")
        sys.exit(1)
    cHeadDict = {'lin': LinearClassifier,
        'con': ConProjHead,
        'cos': CosSim, 'rc': RecipNorm,
        'rcn': RecipNonlinearNorm,
        'ep': EnergyProjHead}
    cH = cHeadDict['con'] if args.method == 'finetune' else cHeadDict[args.head]
    head = cH(args, enc.last_dim, args.n_cls, args.feat_dim)
    model = HEM(enc, head)
    return model


class NormalizeInput(Module):  # TODO: move this into the data transforms
    """Standardise a batch with fixed per-channel statistics.

    The supplied ``mean`` and ``std`` are remapped to ``2 * mean - 1`` and
    ``2 * std`` and stored as buffers of shape ``(1, C, 1, 1)``.
    NOTE(review): this remapping presumably targets inputs scaled to
    [-1, 1] -- confirm against the data pipeline.

    Args:
        mean: per-channel means (tuple, list, array or tensor).
        std: per-channel standard deviations, same length as ``mean``.
    """

    def __init__(self, mean=(0.4914, 0.4822, 0.4465),
                 std=(0.2023, 0.1994, 0.2010)):
        super().__init__()
        # Convert to tensors *before* doing arithmetic: on a plain tuple or
        # list, ``2 * mean`` repeats the sequence and ``- 1.`` then raises
        # TypeError, which made the default arguments unusable.
        mean_t = torch.as_tensor(mean, dtype=torch.float64)
        std_t = torch.as_tensor(std, dtype=torch.float64)
        # Compute in float64 and store as float32.
        r_mean = (2 * mean_t - 1.).float()
        r_std = (2 * std_t).float()
        self.register_buffer('r_mean', r_mean.reshape(1, -1, 1, 1))
        self.register_buffer('r_std', r_std.reshape(1, -1, 1, 1))

    def forward(self, x):
        """Return ``(x - r_mean) / r_std``, broadcast over the stored buffers."""
        return (x - self.r_mean) / self.r_std
